﻿using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;

namespace Infrastructure.Utils.Helpers
{
    /// <summary>
    /// Helper methods for locating, loading and inspecting assemblies and their types.
    /// </summary>
    public static partial class AssemblyHelper
    {
        #region Members 成员变量
        // Cache of default namespaces keyed by assembly; read and filled by GetDefultNamespace.
        static Dictionary<Assembly, string> defaultNamespaces = new Dictionary<Assembly, string>();
        
        // Backing field for the Assemblies property; stays null until the property is first read.
        private static IEnumerable<Assembly> _assemblies;
        // Byte length of a public key token; GetAssemblyFullName only accepts a hex string twice this long.
        const int PublicKeyTokenBytesLength = 8;
        // Backing field for the EntryAssembly property.
        static Assembly entryAssembly;
        #endregion

        #region Construction 构造函数

        #endregion


        #region Properties 属性
        /// <summary>
        /// Gets or sets the assembly treated as the entry assembly.
        /// While no value is stored, the getter asks <see cref="Assembly.GetEntryAssembly"/>
        /// and stores whatever it returns (which may still be null).
        /// </summary>
        public static Assembly EntryAssembly
        {
            get
            {
                return entryAssembly ?? (entryAssembly = Assembly.GetEntryAssembly());
            }
            set
            {
                entryAssembly = value;
            }
        }

        /// <summary>
        /// Gets the assemblies returned by <c>GetLoadedAssemblies()</c>.
        /// The result of the first access is stored and returned on every later access.
        /// </summary>
        public static IEnumerable<Assembly> Assemblies
        {
            get
            {
                if (_assemblies == null)
                {
                    _assemblies = GetLoadedAssemblies();
                }
                return _assemblies;
            }
        }

        
        #endregion

        #region Methods
        /// <summary>
        /// Tries to load an assembly into the reflection-only context.
        /// </summary>
        /// <param name="asmName">Display name passed to <see cref="Assembly.ReflectionOnlyLoad(string)"/>.</param>
        /// <returns>The loaded assembly, or null when loading throws for any reason.</returns>
        static Assembly GetReflectionOnlyLoadedAssembly(string asmName)
        {
            Assembly loaded = null;
            try
            {
                loaded = Assembly.ReflectionOnlyLoad(asmName);
            }
            catch
            {
                // Best effort: every failure is reported to the caller as null.
            }
            return loaded;
        }

        /// <summary>
        /// Converts a hexadecimal string into bytes, two characters per byte.
        /// A trailing unpaired character is ignored.
        /// </summary>
        /// <param name="str">Hexadecimal text such as "0aff".</param>
        /// <returns>An array of <c>str.Length / 2</c> bytes.</returns>
        static byte[] StringToBytes(string str)
        {
            byte[] result = new byte[str.Length / 2];
            for (int index = 0, offset = 0; index < result.Length; index++, offset += 2)
            {
                string pair = str.Substring(offset, 2);
                result[index] = byte.Parse(pair, System.Globalization.NumberStyles.HexNumber);
            }
            return result;
        }

        
        /// <summary>
        /// Builds an assembly display name from its individual parts.
        /// </summary>
        /// <param name="name">Simple assembly name.</param>
        /// <param name="version">Version text, parsed with the <see cref="Version"/> string constructor.</param>
        /// <param name="culture">Culture of the assembly.</param>
        /// <param name="publicKeyToken">
        /// Public key token as hexadecimal text. It is applied only when it is
        /// non-null and exactly <c>2 * PublicKeyTokenBytesLength</c> characters long.
        /// </param>
        /// <returns>The <see cref="AssemblyName.FullName"/> of the composed name.</returns>
        public static string GetAssemblyFullName(string name, string version, System.Globalization.CultureInfo culture, string publicKeyToken)
        {
            AssemblyName composed = new AssemblyName
            {
                Name = name,
                Version = new Version(version),
                CultureInfo = culture
            };

            bool hasUsableToken = publicKeyToken != null
                && publicKeyToken.Length == 2 * PublicKeyTokenBytesLength;
            if (hasUsableToken)
            {
                byte[] tokenBytes = StringToBytes(publicKeyToken);
                composed.SetPublicKeyToken(tokenBytes);
            }

            return composed.FullName;
        }


       

        /// <summary>
        /// Checks whether an assembly full name contains the given text.
        /// </summary>
        /// <param name="assemblyFullName">Full display name to search in.</param>
        /// <param name="assemblyName">Text to look for.</param>
        /// <returns>The result of <c>AssertAssemblyName</c> for the two arguments.</returns>
        public static bool NameContains(string assemblyFullName, string assemblyName)
        {
            bool contains = AssertAssemblyName(assemblyFullName, assemblyName);
            return contains;
        }

        /// <summary>
        /// Checks whether the full name of an assembly contains the given text.
        /// </summary>
        /// <param name="assembly">Assembly whose <see cref="Assembly.FullName"/> is searched.</param>
        /// <param name="assemblyName">Text to look for.</param>
        /// <returns>The result of <c>AssertAssemblyName</c> for the full name and the text.</returns>
        public static bool NameContains(Assembly assembly, string assemblyName)
        {
            string fullName = assembly.FullName;
            return AssertAssemblyName(fullName, assemblyName);
        }
        /// <summary>
        /// Checks whether the full name of an <see cref="AssemblyName"/> contains the given text.
        /// </summary>
        /// <param name="assembly">Assembly name whose <see cref="AssemblyName.FullName"/> is searched.</param>
        /// <param name="assemblyName">Text to look for.</param>
        /// <returns>The result of <c>AssertAssemblyName</c> for the full name and the text.</returns>
        public static bool NameContains(AssemblyName assembly, string assemblyName)
        {
            string fullName = assembly.FullName;
            return AssertAssemblyName(fullName, assemblyName);
        }

        #region Attribute 特性
        /// <summary>
        /// Checks whether the loaded assembly with the given name carries an attribute.
        /// </summary>
        /// <param name="assemblyName">Name used to look the assembly up via <c>GetLoadedAssembly</c>.</param>
        /// <param name="attributeType">Attribute type to look for.</param>
        /// <returns>The result of the <c>HasAttribute(Assembly, Type)</c> overload for the looked-up assembly.</returns>
        public static bool HasAttribute(string assemblyName, Type attributeType)
        {
            Assembly loaded = GetLoadedAssembly(assemblyName);
            return HasAttribute(loaded, attributeType);
        }

        /// <summary>
        /// Checks whether an assembly carries an attribute of the given type.
        /// </summary>
        /// <param name="assembly">Assembly to inspect; null yields false.</param>
        /// <param name="attributeType">Attribute type to look for.</param>
        /// <returns>True when the assembly is non-null and the attribute is defined on it.</returns>
        public static bool HasAttribute(Assembly assembly, Type attributeType)
        {
            return assembly != null && Attribute.IsDefined(assembly, attributeType);
        }

        /// <summary>
        /// Checks whether a type carries an attribute of type <typeparamref name="TAttribute"/>,
        /// searching the inheritance chain as well.
        /// </summary>
        /// <typeparam name="TAttribute">Attribute type to look for.</typeparam>
        /// <param name="type">Type to inspect.</param>
        /// <returns>True when at least one matching attribute is defined.</returns>
        public static bool HasAttribute<TAttribute>(this Type type) where TAttribute : Attribute
        {
            // IsDefined answers the presence question without instantiating the attribute objects.
            return type.IsDefined(typeof(TAttribute), true);
        }

        /// <summary>
        /// Checks whether a property carries an attribute of type <typeparamref name="TAttribute"/>.
        /// </summary>
        /// <typeparam name="TAttribute">Attribute type to look for.</typeparam>
        /// <param name="property">Property to inspect.</param>
        /// <returns>True when at least one matching attribute is defined.</returns>
        public static bool HasAttribute<TAttribute>(this PropertyInfo property) where TAttribute : Attribute
        {
            // IsDefined answers the presence question without instantiating the attribute objects.
            return property.IsDefined(typeof(TAttribute), true);
        }
        #endregion

        /// <summary>
        /// Returns the assemblies matching the given names by delegating to <c>GetLoadedAssemblies</c>.
        /// </summary>
        /// <param name="assembliesName">Names of the assemblies to find.</param>
        /// <returns>Whatever <c>GetLoadedAssemblies</c> returns for the names.</returns>
        public static IEnumerable<Assembly> FindAssemblies(IEnumerable<string> assembliesName)
        {
            IEnumerable<Assembly> found = GetLoadedAssemblies(assembliesName);
            return found;
        }


        /// <summary>
        /// Finds the types in an assembly that implement <typeparamref name="TInterface"/>.
        /// </summary>
        /// <typeparam name="TInterface">Interface to look for.</typeparam>
        /// <param name="assembly">Assembly whose types are searched.</param>
        /// <returns>The result of the non-generic overload for <c>typeof(TInterface)</c>.</returns>
        public static IEnumerable<Type> FindTypesByInterface<TInterface>(this Assembly assembly)
        {
            Type interfaceType = typeof(TInterface);
            return FindTypesByInterface(assembly, interfaceType);
        }

        /// <summary>
        /// Finds the types in an assembly that implement the given interface.
        /// </summary>
        /// <remarks>
        /// Interfaces are compared by identity rather than by simple name, so interfaces
        /// that share a name across namespaces, or differ only in generic arguments, no
        /// longer match each other. Passing an open generic definition such as
        /// <c>typeof(IEnumerable&lt;&gt;)</c> matches every closed construction of it.
        /// </remarks>
        /// <param name="assembly">Assembly whose types are searched.</param>
        /// <param name="interfaceType">Interface type, closed or open generic definition.</param>
        /// <returns>A lazily evaluated sequence of the implementing types.</returns>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        public static IEnumerable<Type> FindTypesByInterface(this Assembly assembly, Type interfaceType)
        {
            if (assembly == null)
                throw new ArgumentNullException("assembly");
            if (interfaceType == null)
                throw new ArgumentNullException("interfaceType");

            // Only an open generic definition is allowed to match closed constructions.
            bool matchConstructions = interfaceType.IsGenericTypeDefinition;

            return assembly.GetTypes().Where(type => type.GetInterfaces().Any(implemented =>
                implemented == interfaceType
                || (matchConstructions
                    && implemented.IsGenericType
                    && implemented.GetGenericTypeDefinition() == interfaceType)));
        }

        /// <summary>
        /// Finds the types implementing <typeparamref name="TInterface"/> across several assemblies.
        /// </summary>
        /// <typeparam name="TInterface">Interface to look for.</typeparam>
        /// <param name="assembliesName">Assembly names forwarded to the non-generic overload; may be null.</param>
        /// <returns>The result of the non-generic overload for <c>typeof(TInterface)</c>.</returns>
        public static IEnumerable<Type> FindAllTypesByInterface<TInterface>(string[] assembliesName = null)
        {
            Type interfaceType = typeof(TInterface);
            return FindAllTypesByInterface(interfaceType, assembliesName);
        }

        /// <summary>
        /// Finds the types implementing an interface across the assemblies returned by
        /// <c>GetLoadedAssemblies</c> for the given names.
        /// </summary>
        /// <param name="interfaceType">Interface to look for.</param>
        /// <param name="assembliesName">Assembly names passed to <c>GetLoadedAssemblies</c>; may be null.</param>
        /// <returns>A lazily evaluated, flattened sequence of matching types.</returns>
        public static IEnumerable<Type> FindAllTypesByInterface(Type interfaceType, string[] assembliesName = null)
        {
            return from assembly in GetLoadedAssemblies(assembliesName)
                   from type in assembly.FindTypesByInterface(interfaceType)
                   select type;
        }

        /// <summary>
        /// Finds the types in an assembly that derive from the given base type.
        /// The base type itself is not part of the result.
        /// </summary>
        /// <param name="assembly">Assembly whose types are searched.</param>
        /// <param name="baseType">Base type the results must be subclasses of.</param>
        /// <returns>A lazily evaluated sequence of the derived types.</returns>
        public static IEnumerable<Type> FindTypesByBaseType(this Assembly assembly, Type baseType)
        {
            return from candidate in assembly.GetTypes()
                   where candidate.IsSubclassOf(baseType)
                   select candidate;
        }

        /// <summary>
        /// Finds the types deriving from a base type across the assemblies returned by
        /// <c>GetLoadedAssemblies</c> for the given names.
        /// </summary>
        /// <param name="baseType">Base type the results must be subclasses of.</param>
        /// <param name="assembliesName">Assembly names passed to <c>GetLoadedAssemblies</c>; may be null.</param>
        /// <returns>A lazily evaluated, flattened sequence of matching types.</returns>
        public static IEnumerable<Type> FindAllTypesByBaseType(Type baseType, string[] assembliesName = null)
        {
            return from assembly in GetLoadedAssemblies(assembliesName)
                   from type in assembly.FindTypesByBaseType(baseType)
                   select type;
        }

      

        /// <summary>
        /// Gets the first attribute of type <typeparamref name="TAttribute"/> on a type,
        /// searching the inheritance chain as well.
        /// </summary>
        /// <typeparam name="TAttribute">Attribute type to retrieve.</typeparam>
        /// <param name="type">Type to inspect.</param>
        /// <returns>The first matching attribute, or null when there is none.</returns>
        public static TAttribute GetAttribute<TAttribute>(this Type type) where TAttribute : Attribute
        {
            object[] matches = type.GetCustomAttributes(typeof(TAttribute), true);
            return matches.Length > 0 ? (TAttribute)matches[0] : null;
        }

        /// <summary>
        /// Gets the first attribute of type <typeparamref name="TAttribute"/> on a property.
        /// </summary>
        /// <typeparam name="TAttribute">Attribute type to retrieve.</typeparam>
        /// <param name="property">Property to inspect.</param>
        /// <returns>The first matching attribute, or null when there is none.</returns>
        public static TAttribute GetAttribute<TAttribute>(this PropertyInfo property) where TAttribute : Attribute
        {
            object[] matches = property.GetCustomAttributes(typeof(TAttribute), true);
            return matches.Length > 0 ? (TAttribute)matches[0] : null;
        }

        /// <summary>
        /// Gets the assemblies of the current application domain, optionally restricted
        /// to a set of names.
        /// </summary>
        /// <remarks>
        /// When names are given, any requested assembly that is not loaded yet is loaded
        /// first; exceptions thrown by <see cref="AppDomain.Load(string)"/> propagate to
        /// the caller. Names may be simple names or full display names; matching is done
        /// on the simple name, ignoring case.
        /// </remarks>
        /// <param name="assembliesName">Names of the wanted assemblies, or null for all loaded assemblies.</param>
        /// <returns>
        /// All loaded assemblies when <paramref name="assembliesName"/> is null; otherwise
        /// an array of the loaded assemblies whose simple name was requested.
        /// </returns>
        public static IEnumerable<Assembly> GetLoadedAssemblies(IEnumerable<string> assembliesName=null)
        {
            if (assembliesName == null)
                return AppDomain.CurrentDomain.GetAssemblies();

            // Parse once so the input sequence is enumerated a single time.
            List<AssemblyName> requested = assembliesName.Select(n => new AssemblyName(n)).ToList();

            HashSet<string> loadedNames = new HashSet<string>(
                AppDomain.CurrentDomain.GetAssemblies().Select(a => a.GetName().Name),
                StringComparer.OrdinalIgnoreCase);

            foreach (AssemblyName wanted in requested)
            {
                // Add returns false when the name is already present, i.e. already loaded.
                if (loadedNames.Add(wanted.Name))
                    AppDomain.CurrentDomain.Load(wanted.FullName);
            }

            HashSet<string> wantedNames = new HashSet<string>(
                requested.Select(r => r.Name),
                StringComparer.OrdinalIgnoreCase);

            return AppDomain.CurrentDomain.GetAssemblies()
                .Where(a => wantedNames.Contains(a.GetName().Name))
                .ToArray();
        }

        #region 判断
        /// <summary>
        /// Checks whether an assembly with the given name is currently loaded.
        /// </summary>
        /// <param name="assemblyName">Assembly name passed to <c>GetLoadedAssembly</c>.</param>
        /// <returns>True when <c>GetLoadedAssembly</c> finds an assembly.</returns>
        public static bool IsLoadedAssembly(string assemblyName)
        {
            Assembly found = GetLoadedAssembly(assemblyName);
            return found != null;
        }

        /// <summary>
        /// Checks whether the given assembly is the one exposed by <c>EntryAssembly</c>.
        /// </summary>
        /// <param name="assembly">Assembly to compare.</param>
        /// <returns>True when <c>EntryAssembly</c> equals <paramref name="assembly"/>.</returns>
        public static bool IsEntryAssembly(Assembly assembly)
        {
            return EntryAssembly == assembly;
        }
        /// <summary>
        /// Checks whether the full name of the entry assembly contains the given text.
        /// </summary>
        /// <param name="assemblyName">Text to look for in the entry assembly's full name.</param>
        /// <returns>False when there is no entry assembly; otherwise the result of <c>NameContains</c>.</returns>
        public static bool IsEntryAssembly(string assemblyName)
        {
            Assembly entry = EntryAssembly;
            return entry != null && NameContains(entry, assemblyName);
        }
        #endregion

        /// <summary>
        /// Finds the first loaded assembly whose partial name equals the partial name
        /// of <paramref name="asmName"/>.
        /// </summary>
        /// <param name="asmName">Simple or full assembly name to look for.</param>
        /// <returns>The matching assembly, or null when none matches.</returns>
        public static Assembly GetLoadedAssembly(string asmName)
        {
            return GetLoadedAssemblies()
                .FirstOrDefault(candidate => PartialNameEquals(candidate.FullName, asmName));
        }

        /// <summary>
        /// Gets an assembly by name, preferring one that is already loaded and falling
        /// back to <see cref="Assembly.Load(string)"/>.
        /// </summary>
        /// <param name="assemblyFullName">Name of the assembly.</param>
        /// <returns>The already loaded or newly loaded assembly.</returns>
        public static Assembly GetAssembly(string assemblyFullName)
        {
            Assembly alreadyLoaded = AssemblyHelper.GetLoadedAssembly(assemblyFullName);
            return alreadyLoaded ?? Assembly.Load(assemblyFullName);
        }

        /// <summary>
        /// Checks whether an assembly full name contains the given text, ignoring case.
        /// Despite its name this method never throws; it only reports the result.
        /// </summary>
        /// <param name="fullName">Full display name to search in; null or empty yields false.</param>
        /// <param name="assemblyName">Text to look for; null or empty yields false.</param>
        /// <returns>True when <paramref name="assemblyName"/> occurs in <paramref name="fullName"/>.</returns>
        static bool AssertAssemblyName(string fullName, string assemblyName)
        {
            if (string.IsNullOrEmpty(fullName) || string.IsNullOrEmpty(assemblyName))
                return false;

            // A case-insensitive ordinal search avoids allocating two lower-cased copies.
            return fullName.IndexOf(assemblyName, StringComparison.OrdinalIgnoreCase) >= 0;
        }

        

        /// <summary>
        /// Compares the partial names of two assembly names, ignoring case
        /// (invariant culture).
        /// </summary>
        /// <param name="asmName0">First assembly name.</param>
        /// <param name="asmName1">Second assembly name.</param>
        /// <returns>True when both partial names are equal.</returns>
        public static bool PartialNameEquals(string asmName0, string asmName1)
        {
            string left = GetPartialName(asmName0);
            string right = GetPartialName(asmName1);
            return string.Equals(left, right, StringComparison.InvariantCultureIgnoreCase);
        }

        /// <summary>
        /// Returns the part of an assembly name that precedes the first comma.
        /// </summary>
        /// <param name="asmName">Simple or full assembly name.</param>
        /// <returns>The text before the first comma, or the whole input when it has no comma.</returns>
        public static string GetPartialName(string asmName)
        {
            int commaIndex = asmName.IndexOf(',');
            if (commaIndex >= 0)
            {
                return asmName.Substring(0, commaIndex);
            }
            return asmName;
        }

        /// <summary>
        /// Returns the part of an assembly's full name that precedes the first comma.
        /// </summary>
        /// <param name="assembly">Assembly whose <see cref="Assembly.FullName"/> is used.</param>
        /// <returns>The partial name of the assembly.</returns>
        public static string GetPartialName(Assembly assembly)
        {
            string fullName = assembly.FullName;
            return GetPartialName(fullName);
        }

        /// <summary>
        /// Gets the namespace of a type.
        /// </summary>
        /// <remarks>
        /// Uses <see cref="Type.Namespace"/> instead of parsing <see cref="Type.FullName"/>.
        /// The full name of a closed generic type contains dots inside its
        /// assembly-qualified type arguments, and it is null for generic type
        /// parameters, so text parsing gave wrong results or failed for those types.
        /// </remarks>
        /// <param name="type">Type to inspect.</param>
        /// <returns>The namespace, or an empty string when the type has none.</returns>
        public static string GetNamespace(Type type)
        {
            return type.Namespace ?? string.Empty;
        }
        /// <summary>
        /// Gets the default namespace of an assembly, computing it with
        /// <c>GetDefultNamespaceCore</c> on first request and caching it per assembly.
        /// </summary>
        /// <param name="assembly">Assembly to inspect.</param>
        /// <returns>The cached or newly computed default namespace.</returns>
        public static string GetDefultNamespace(Assembly assembly)
        {
            string cached;
            if (defaultNamespaces.TryGetValue(assembly, out cached))
            {
                return cached;
            }

            string computed = GetDefultNamespaceCore(assembly);
            defaultNamespaces.Add(assembly, computed);
            return computed;
        }
        /// <summary>
        /// Computes the longest common prefix of the given strings, ignoring every
        /// string that ends with one of the excluded suffixes (ordinal comparison).
        /// </summary>
        /// <param name="strings">Candidate strings.</param>
        /// <param name="excludedSuffixes">Suffixes that remove a string from consideration.</param>
        /// <returns>The common prefix, or an empty string when no string remains after filtering.</returns>
        public static string GetCommonPart(string[] strings, string[] excludedSuffixes)
        {
            List<string> candidates = strings
                .Where(s => !excludedSuffixes.Any(suffix => s.EndsWith(suffix, StringComparison.Ordinal)))
                .ToList();
            if (candidates.Count == 0)
            {
                return string.Empty;
            }

            // The first remaining string is the reference every other one is compared against.
            string reference = candidates[0];
            int prefixLength = 0;
            for (; prefixLength < reference.Length; prefixLength++)
            {
                char expected = reference[prefixLength];
                bool mismatch = false;
                foreach (string candidate in candidates)
                {
                    if (prefixLength >= candidate.Length || candidate[prefixLength] != expected)
                    {
                        mismatch = true;
                        break;
                    }
                }
                if (mismatch)
                {
                    break;
                }
            }
            return reference.Substring(0, prefixLength);
        }

        /// <summary>
        /// Derives a default namespace from the manifest resource names of an assembly.
        /// </summary>
        /// <param name="assembly">Assembly to inspect.</param>
        /// <returns>
        /// An empty string when the assembly has no manifest resources; the partial
        /// assembly name followed by a dot when it has exactly one; otherwise the common
        /// prefix of the resource names, ignoring names ending in ".csdl", ".ssdl" or ".msl".
        /// </returns>
        static string GetDefultNamespaceCore(Assembly assembly)
        {
            string[] resourceNames = assembly.GetManifestResourceNames();
            switch (resourceNames.Length)
            {
                case 0:
                    return string.Empty;
                case 1:
                    return GetPartialName(assembly) + ".";
                default:
                    return GetCommonPart(resourceNames, new[] { ".csdl", ".ssdl", ".msl" });
            }
        }
        #endregion

    } 
}
